"""Export inference graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys

import tensorflow as tf

import deeplab_model
from utils import preprocessing

parser = argparse.ArgumentParser()

parser.add_argument('--model_dir', type=str, default='./model',
                    help="Base directory for the model. "
                         "Make sure 'model_checkpoint_path' given in 'checkpoint' file matches "
                         "with checkpoint name.")

parser.add_argument('--export_dir', type=str, default='dataset/export_output',
                    help='The directory where the exported SavedModel will be stored.')

parser.add_argument('--base_architecture', type=str, default='resnet_v2_101',
                    choices=['resnet_v2_50', 'resnet_v2_101'],
                    help='The architecture of base Resnet building block.')

parser.add_argument('--output_stride', type=int, default=16,
                    choices=[8, 16],
                    help='Output stride for DeepLab v3. Currently 8 or 16 is supported.')

_NUM_CLASSES = 21


def main(unused_argv):
    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    model = tf.estimator.Estimator(
        model_fn=deeplab_model.deeplabv3_plus_model_fn,
        model_dir=FLAGS.model_dir,
        params={
            'output_stride': FLAGS.output_stride,
            'batch_size': 1,  # Batch size must be 1 because the images' size may differ
            'base_architecture': FLAGS.base_architecture,
            'pre_trained_model': None,
            'batch_norm_decay': None,
            'num_classes': _NUM_CLASSES,
        })

    # Export the model
    def serving_input_receiver_fn():
        image = tf.placeholder(tf.float32, [None, None, None, 3], name='image_tensor')
        receiver_tensors = {'image': image}
        features = tf.map_fn(preprocessing.mean_image_subtraction, image)
        return tf.estimator.export.ServingInputReceiver(
            features=features,
            receiver_tensors=receiver_tensors)

    model.export_savedmodel(FLAGS.export_dir, serving_input_receiver_fn)


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
